import pandas as pd
import numpy as np
from math import log
import operator

def loadData(filename):
    '''
    Load a CSV data set.

    Input:  filename - CSV source, handed unchanged to pandas.read_csv
    Output: the parsed pandas DataFrame
    '''
    return pd.read_csv(filename)

def calcShannonEnt(dataset):
    '''
    Compute the Shannon entropy (base 2) of a data set's class labels.

    Input:  dataset - DataFrame whose LAST column holds the class label
    Output: entropy of the label distribution as a float
            (0.0 when the data set has no rows)
    '''
    total = dataset.shape[0]
    labelColumn = dataset.columns.tolist()[-1]

    # Tally how many rows carry each class label.
    tally = {}
    for label in dataset[labelColumn].tolist():
        tally[label] = tally.get(label, 0) + 1

    # H = -sum(p * log2(p)) over the observed labels.
    entropy = 0.0
    for count in tally.values():
        p = count / total
        entropy -= p * log(p, 2)
    return entropy

def splitDataSet(dataset, axis, value):
    '''
    Select the rows where one feature equals a given value and drop that
    feature's column.

    Input:  dataset - pandas DataFrame
            axis    - name of the column to split on
            value   - keep only rows whose `axis` entry equals this value
    Output: a new DataFrame without the `axis` column, containing only the
            matching rows, re-indexed from 0. The input is not modified.
    '''
    # Select rows with a boolean mask instead of collecting positional row
    # numbers and passing them to DataFrame.drop: drop() interprets them as
    # index *labels*, which is only correct for a default 0..n-1 index.
    matches = dataset[axis] == value
    subset = dataset.loc[matches].drop(columns=[axis])
    return subset.reset_index(drop=True)


def chooseBestFeatureToSplit(dataset):
    '''
    Pick the feature with the largest information gain.

    Input:  dataset - DataFrame; every column except the last is a feature,
                      the last column is the class label
    Output: (bestFeature, bestInfoGain) - the winning column name and its
            gain; (-1, 0.0) when no feature has a strictly positive gain
    Note:   prints each feature name with its information gain.
    '''
    featureNames = dataset.columns.tolist()[:-1]
    baseEntropy = calcShannonEnt(dataset)
    totalRows = dataset.shape[0]

    winner, winnerGain = -1, 0.0
    for name in featureNames:
        # Conditional entropy: size-weighted entropy of each partition.
        conditionalEntropy = 0.0
        for val in set(dataset[name].tolist()):
            part = splitDataSet(dataset, name, val)
            weight = part.shape[0] / totalRows
            conditionalEntropy += weight * calcShannonEnt(part)

        gain = baseEntropy - conditionalEntropy
        print(name, gain)
        if gain > winnerGain:
            winner, winnerGain = name, gain
    return winner, winnerGain

def majorityCnt(classList):
    '''
    Return the most frequent class label (majority vote).

    Used when every feature has been consumed but the remaining samples
    still carry more than one class label.

    Input:  classList - iterable of hashable class labels
    Output: the label that occurs most often; on a tie, the label that
            appears first in classList wins
    Raises: ValueError if classList is empty
    '''
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1

    if not classCount:
        raise ValueError("majorityCnt() requires at least one class label")

    # max() returns the first maximal item, and dicts keep insertion order,
    # so ties resolve to the earliest label seen.
    return max(classCount.items(), key=operator.itemgetter(1))[0]

def createTree(dataset, dropCol):
    '''
    Recursively build an ID3 decision tree.

    Input:  dataset - DataFrame; every column except the last is a feature,
                      the last column is the class label
            dropCol - unused; kept so existing callers keep working
    Output: a class label for a leaf, otherwise a nested dict of the form
            {feature_name: {feature_value: subtree_or_label, ...}}
    Raises: IndexError if the data set has no rows
    Note:   prints diagnostic traces at every recursion step.
    '''
    print(dataset.columns.tolist())
    cols = dataset.columns.tolist()[:-1]
    print(cols)
    classList = dataset[dataset.columns.tolist()[-1]].tolist()
    print(classList)

    # All samples belong to one class: single-node tree labelled with it.
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # No features left to split on: label the node with the majority class.
    # (Checks the feature list, not the row count, which is never zero here.)
    if not cols:
        return majorityCnt(classList)

    print('特征集和类别:',dataset.columns.tolist())
    bestFeature, bestInfoGain = chooseBestFeatureToSplit(dataset)
    print('bestFeture:',bestFeature)

    # No feature yields a positive information gain, so bestFeature is the
    # -1 sentinel rather than a column name; splitting further cannot help.
    if bestInfoGain <= 0.0:
        return majorityCnt(classList)

    myTree = {bestFeature: {}}

    # One branch per distinct value of the chosen feature.
    uniqueVals = set(dataset[bestFeature])
    for value in uniqueVals:
        myTree[bestFeature][value] = createTree(splitDataSet(dataset, bestFeature, value), bestFeature)
    return myTree

def main():
    '''
    Demo driver: load the sample CSV, build the decision tree and print it.
    '''
    data = loadData("../../data/test_id3.csv")

    print(data.values)
    print("dataSet 数据：")

    print(data.columns.tolist() )
    tree = createTree(data, [])
    print(tree)

# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()